{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notebook written by [Zhedong Zheng](https://github.com/zhedongzheng)\n",
    "\n",
    "<img src=\"img/tfidf.gif\" width=\"400\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.feature_extraction.text import TfidfTransformer\n",
    "from sklearn.utils import shuffle\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "VOCAB_SIZE = 20000\n",
    "N_CLASS = 2\n",
    "BATCH_SIZE = 32\n",
    "N_EPOCH = 2\n",
    "LR = 5e-3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sparse_tfidf(X):\n",
    "    \"\"\"Turn documents of token ids into a sparse TF-IDF matrix.\n",
    "\n",
    "    The document-term count matrix is assembled directly in CSR form, so\n",
    "    memory scales with the number of tokens instead of the\n",
    "    len(X) * VOCAB_SIZE * 8 bytes a dense float64 matrix would need.\n",
    "\n",
    "    Args:\n",
    "        X: sequence of documents, each a sequence of integer token ids\n",
    "            in the range [0, VOCAB_SIZE).\n",
    "\n",
    "    Returns:\n",
    "        The output of TfidfTransformer().fit_transform on the count\n",
    "        matrix: one row per document, VOCAB_SIZE columns.\n",
    "\n",
    "    Raises:\n",
    "        IndexError: if any token id lies outside [0, VOCAB_SIZE).\n",
    "    \"\"\"\n",
    "    # Imported locally because the notebook's import cell does not pull in scipy.\n",
    "    from scipy.sparse import csr_matrix\n",
    "\n",
    "    t0 = time.time()\n",
    "    lengths = [len(indices) for indices in X]\n",
    "    # indptr[i]:indptr[i+1] delimits the token ids of document i.\n",
    "    indptr = np.concatenate(([0], np.cumsum(lengths, dtype=np.int64)))\n",
    "    cols = np.fromiter((idx for indices in X for idx in indices),\n",
    "                       dtype=np.int64, count=int(indptr[-1]))\n",
    "    # The CSR constructor does not bounds-check column ids, so do it here.\n",
    "    if cols.size and (cols.min() < 0 or cols.max() >= VOCAB_SIZE):\n",
    "        raise IndexError(\"token id out of range for VOCAB_SIZE=%d\" % VOCAB_SIZE)\n",
    "    count = csr_matrix((np.ones(cols.size), cols, indptr),\n",
    "                       shape=(len(lengths), VOCAB_SIZE))\n",
    "    # Repeated tokens in a document become a single entry holding their count.\n",
    "    count.sum_duplicates()\n",
    "    print(\"%.2f secs ==> Document-Term Matrix\"%(time.time()-t0))\n",
    "\n",
    "    t0 = time.time()\n",
    "    tfidf = TfidfTransformer().fit_transform(count)\n",
    "    print(\"%.2f secs ==> TF-IDF transform\\n\"%(time.time()-t0))\n",
    "    return tfidf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def next_train_batch(X, y):\n",
    "    \"\"\"Yield (features, labels) mini-batches of at most BATCH_SIZE rows.\n",
    "\n",
    "    Each slice of X is converted with .toarray() right before it is\n",
    "    yielded; y is sliced over the same row range.\n",
    "    \"\"\"\n",
    "    n_rows = X.shape[0]\n",
    "    for lo in range(0, n_rows, BATCH_SIZE):\n",
    "        hi = lo + BATCH_SIZE\n",
    "        yield X[lo:hi].toarray(), y[lo:hi]\n",
    "\n",
    "def next_test_batch(X):\n",
    "    \"\"\"Yield feature-only mini-batches of at most BATCH_SIZE rows.\n",
    "\n",
    "    Each slice of X is converted with .toarray() right before it is yielded.\n",
    "    \"\"\"\n",
    "    starts = range(0, X.shape[0], BATCH_SIZE)\n",
    "    yield from (X[lo:lo + BATCH_SIZE].toarray() for lo in starts)\n",
    "\n",
    "def train_input_fn(X_train, y_train):\n",
    "    \"\"\"Build the training input tensors for the Estimator.\n",
    "\n",
    "    Wraps next_train_batch in a tf.data generator dataset declared as\n",
    "    float32 features of shape [None, VOCAB_SIZE] and int64 labels of\n",
    "    shape [None], and returns the one-shot iterator's get_next() tensors.\n",
    "    \"\"\"\n",
    "    def batches():\n",
    "        return next_train_batch(X_train, y_train)\n",
    "\n",
    "    feature_shape = tf.TensorShape([None, VOCAB_SIZE])\n",
    "    label_shape = tf.TensorShape([None])\n",
    "    dataset = tf.data.Dataset.from_generator(\n",
    "        batches,\n",
    "        output_types=(tf.float32, tf.int64),\n",
    "        output_shapes=(feature_shape, label_shape))\n",
    "    return dataset.make_one_shot_iterator().get_next()\n",
    "\n",
    "def predict_input_fn(X_test):\n",
    "    \"\"\"Build the prediction input tensor for the Estimator.\n",
    "\n",
    "    Wraps next_test_batch in a tf.data generator dataset declared as\n",
    "    float32 with shape [None, VOCAB_SIZE], and returns the one-shot\n",
    "    iterator's get_next() tensor.\n",
    "    \"\"\"\n",
    "    def batches():\n",
    "        return next_test_batch(X_test)\n",
    "\n",
    "    dataset = tf.data.Dataset.from_generator(\n",
    "        batches,\n",
    "        output_types=tf.float32,\n",
    "        output_shapes=tf.TensorShape([None, VOCAB_SIZE]))\n",
    "    return dataset.make_one_shot_iterator().get_next()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_fn(features, labels, mode, params):\n",
    "    \"\"\"Single dense-layer softmax classifier for tf.estimator.Estimator.\n",
    "\n",
    "    Args:\n",
    "        features: input tensor passed to one dense layer with N_CLASS units.\n",
    "        labels: integer class ids; not used in PREDICT mode.\n",
    "        mode: a tf.estimator.ModeKeys value.\n",
    "        params: accepted for the Estimator signature; not read here.\n",
    "\n",
    "    Returns:\n",
    "        A tf.estimator.EstimatorSpec:\n",
    "        PREDICT -> argmax class ids as predictions;\n",
    "        EVAL    -> mean cross-entropy loss plus an \"accuracy\" metric;\n",
    "        TRAIN   -> mean cross-entropy loss and an Adam train op (rate LR).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if mode is none of PREDICT, EVAL or TRAIN.\n",
    "    \"\"\"\n",
    "    logits = tf.layers.dense(features, N_CLASS)\n",
    "    predictions = tf.argmax(logits, -1)\n",
    "\n",
    "    if mode == tf.estimator.ModeKeys.PREDICT:\n",
    "        return tf.estimator.EstimatorSpec(mode, predictions=predictions)\n",
    "\n",
    "    if mode not in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):\n",
    "        raise ValueError(\"Unsupported mode: %s\" % mode)\n",
    "\n",
    "    # TRAIN and EVAL share the same loss definition.\n",
    "    loss_op = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "        logits=logits, labels=labels))\n",
    "\n",
    "    if mode == tf.estimator.ModeKeys.EVAL:\n",
    "        # Without this branch the function returned None for EVAL, so\n",
    "        # estimator.evaluate() could not be used with this model.\n",
    "        metrics = {\"accuracy\": tf.metrics.accuracy(\n",
    "            labels=labels, predictions=predictions)}\n",
    "        return tf.estimator.EstimatorSpec(\n",
    "            mode=mode, loss=loss_op, eval_metric_ops=metrics)\n",
    "\n",
    "    train_op = tf.train.AdamOptimizer(LR).minimize(loss_op,\n",
    "        global_step=tf.train.get_global_step())\n",
    "\n",
    "    return tf.estimator.EstimatorSpec(mode=mode, loss=loss_op, train_op=train_op)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4.47 secs ==> Document-Term Matrix\n",
      "11.38 secs ==> TF-IDF transform\n",
      "\n",
      "3.51 secs ==> Document-Term Matrix\n",
      "9.38 secs ==> TF-IDF transform\n",
      "\n",
      "INFO:tensorflow:Using default config.\n",
      "WARNING:tensorflow:Using temporary folder as model directory: /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp_0eyd01i\n",
      "INFO:tensorflow:Using config: {'_model_dir': '/var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp_0eyd01i', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x1185e2780>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n",
      "WARNING:tensorflow:Estimator's model_fn (<function model_fn at 0x107a2c598>) includes params argument, but params are not passed to Estimator.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 1 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp_0eyd01i/model.ckpt.\n",
      "INFO:tensorflow:loss = 0.69590914, step = 1\n",
      "INFO:tensorflow:global_step/sec: 359.903\n",
      "INFO:tensorflow:loss = 0.5569855, step = 101 (0.279 sec)\n",
      "INFO:tensorflow:global_step/sec: 386.986\n",
      "INFO:tensorflow:loss = 0.53416413, step = 201 (0.258 sec)\n",
      "INFO:tensorflow:global_step/sec: 383.917\n",
      "INFO:tensorflow:loss = 0.4360689, step = 301 (0.261 sec)\n",
      "INFO:tensorflow:global_step/sec: 381.126\n",
      "INFO:tensorflow:loss = 0.40131736, step = 401 (0.262 sec)\n",
      "INFO:tensorflow:global_step/sec: 385.489\n",
      "INFO:tensorflow:loss = 0.30325773, step = 501 (0.259 sec)\n",
      "INFO:tensorflow:global_step/sec: 385.486\n",
      "INFO:tensorflow:loss = 0.35899192, step = 601 (0.260 sec)\n",
      "INFO:tensorflow:global_step/sec: 379.976\n",
      "INFO:tensorflow:loss = 0.41224274, step = 701 (0.263 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 782 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp_0eyd01i/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 0.19030645.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp_0eyd01i/model.ckpt-782\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "\n",
      "Validation Accuracy: 0.8847\n",
      "\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp_0eyd01i/model.ckpt-782\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 783 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp_0eyd01i/model.ckpt.\n",
      "INFO:tensorflow:loss = 0.20300367, step = 783\n",
      "INFO:tensorflow:global_step/sec: 358.019\n",
      "INFO:tensorflow:loss = 0.21820101, step = 883 (0.280 sec)\n",
      "INFO:tensorflow:global_step/sec: 340.944\n",
      "INFO:tensorflow:loss = 0.3103854, step = 983 (0.293 sec)\n",
      "INFO:tensorflow:global_step/sec: 375.717\n",
      "INFO:tensorflow:loss = 0.20317315, step = 1083 (0.266 sec)\n",
      "INFO:tensorflow:global_step/sec: 366.491\n",
      "INFO:tensorflow:loss = 0.3555227, step = 1183 (0.273 sec)\n",
      "INFO:tensorflow:global_step/sec: 383.751\n",
      "INFO:tensorflow:loss = 0.27495146, step = 1283 (0.260 sec)\n",
      "INFO:tensorflow:global_step/sec: 389.082\n",
      "INFO:tensorflow:loss = 0.29025418, step = 1383 (0.257 sec)\n",
      "INFO:tensorflow:global_step/sec: 371.556\n",
      "INFO:tensorflow:loss = 0.18338819, step = 1483 (0.270 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 1564 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp_0eyd01i/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 0.22793391.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp_0eyd01i/model.ckpt-1564\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "\n",
      "Validation Accuracy: 0.8902\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Load the IMDB dataset as token-id sequences (num_words=VOCAB_SIZE).\n",
    "(X_train, y_train), (X_test, y_test) = tf.keras.datasets.imdb.load_data(\n",
    "    num_words=VOCAB_SIZE)\n",
    "\n",
    "# Replace the raw sequences with their TF-IDF matrices (train first, then test).\n",
    "X_train, X_test = sparse_tfidf(X_train), sparse_tfidf(X_test)\n",
    "\n",
    "estimator = tf.estimator.Estimator(model_fn=model_fn)\n",
    "\n",
    "for _ in range(N_EPOCH):\n",
    "    # shuffle() runs inside the lambda, so every train() call gets a new row order.\n",
    "    estimator.train(\n",
    "        input_fn=lambda: train_input_fn(*shuffle(X_train, y_train)))\n",
    "    y_pred = np.fromiter(\n",
    "        estimator.predict(input_fn=lambda: predict_input_fn(X_test)),\n",
    "        dtype=np.int32)\n",
    "    # Reported as \"Validation Accuracy\" but measured on (X_test, y_test).\n",
    "    print(\"\\nValidation Accuracy: %.4f\\n\" % np.mean(y_pred == y_test))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
